### FIGURE 5 ###
library(ggplot2)
library(xgboost) 
library(maps)
library(dplyr)
library(ggnewscale)


# Panel A: importance of every prophage as a predictor of each defence system,
# one barplot per system. Bars are coloured by how common the prophage is:
# rare prophages on a red scale, common ones on a green scale.

# Prophages shown on every panel, in the short display form ("Phage_" -> "P").
# Features a model never used are absent from xgb.importance(); they are padded
# with zero gain so that all stacked panels share the same x-axis.
ph_vector <- c(
  "DgiS1", "P1003", "P1017", "P1055", "P1231", "P1233", "P1311", "P1023",
  "P1165", "P1068", "P1184", "P1074", "P1372", "P1331", "P1076", "P1012",
  "P1179", "P1151", "P1054", "P1251", "P1371", "P1041", "P1050", "P1011",
  "P1002", "P1009", "P1240", "P1471", "P1075", "P1105", "P1339", "P18",
  "P2140", "P1035", "P1161", "P1255", "P1491", "P1661", "P1140", "P1167",
  "P1059", "P1115", "P1114", "P1031", "P1310", "P1183", "P1107"
)

# Relative abundance below which a prophage is drawn on the red ("rare") scale;
# at or above it, on the green ("common") scale.
presence_cutoff <- 0.1

# Convert feature names to the display form used in the figure.
to_display_name <- function(x) gsub("Phage_", "P", x, fixed = TRUE)

# Build one importance barplot with two fill scales (ggnewscale adds the
# second one). `aus` carries the gain of rare prophages, `pres` the gain of
# common ones. `variant` picks the axis decoration:
#   "bottom" - x-axis title and labels (last panel of the stack),
#   "ylab"   - y-axis title "Importance" (middle panel),
#   "plain"  - no x-axis (every other panel).
build_importance_plot <- function(aus, pres, variant) {
  gp <- ggplot(mapping = aes(x = Feature, y = Gain)) +
    geom_col(data = aus, aes(fill = Rel_Ab), show.legend = FALSE) +
    scale_fill_gradient(low = "red", high = "white",
                        limits = c(0, presence_cutoff))
  if (variant == "plain") {
    # Added before new_scale_fill() so it titles the red scale only.
    gp <- gp + labs(fill = "Absence")
  }
  gp <- gp +
    ggnewscale::new_scale_fill() +
    geom_col(data = pres, aes(fill = Rel_Ab), show.legend = FALSE) +
    scale_fill_gradient2(low = "white", mid = "#76EE00", high = "#458B00",
                         midpoint = 0.4, limits = c(presence_cutoff, 1))

  y_title <- element_text(size = 14)
  y_text <- element_text(size = 15)
  if (variant == "bottom") {
    gp <- gp +
      labs(y = " ", x = "Prophages") +
      theme_minimal() +
      theme(axis.title.x = element_text(size = 14),
            axis.title.y = y_title,
            axis.text.x = element_text(angle = 90, hjust = 1, size = 15,
                                       vjust = 0.5),
            axis.text.y = y_text)
  } else if (variant == "ylab") {
    gp <- gp +
      labs(y = "Importance") +
      theme_minimal() +
      theme(axis.title.x = element_blank(), axis.text.x = element_blank(),
            axis.title.y = y_title, axis.text.y = y_text)
  } else {
    gp <- gp +
      labs(y = " ", fill = "Presence") +
      theme_minimal() +
      theme(axis.title.x = element_blank(), axis.text.x = element_blank(),
            axis.title.y = y_title, legend.position = "right",
            legend.text = element_text(size = 14),
            legend.key.size = unit(8, "mm"), axis.text.y = y_text)
  }
  gp + ylim(0, 0.6)
}

defence_systems <- c("Cas", "CBASS", "Gabija", "RosmerTA", "R-M", "Gao_Qat",
                     "PD-T4-5", "PD-T7-5", "Ssp")

plot_list <- list()
imp_by_type <- list()

for (sys in defence_systems) {
  print(sys)
  # Saved model object: `$model` is passed to xgb.importance() and `$pres` is a
  # genome-by-column presence table (TODO confirm against the training script).
  xgb_model <- readRDS(paste0("xgb_model_defensome_", sys, ".rds"))
  print(length(xgb_model$model$feature_names))

  # Importance per feature, renamed to the display form *before* comparing with
  # ph_vector so that the zero-gain padding actually matches real features.
  imp <- data.frame(xgb.importance(model = xgb_model$model))
  imp$Feature <- to_display_name(imp$Feature)
  phages_ex <- setdiff(ph_vector, imp$Feature)
  if (length(phages_ex) > 0) {
    imp <- rbind(imp, data.frame(Feature = phages_ex, Gain = 0, Cover = 0,
                                 Frequency = 0))
  }
  print(max(imp$Gain))
  imp_by_type[[sys]] <- cbind(imp, types = sys)

  # Relative abundance of each prophage across genomes.
  # NOTE(review): column 48 is dropped by position - presumably the response
  # column of the presence table; confirm.
  pres <- xgb_model$pres
  ab_ph_rel <- colSums(pres[, -48]) / nrow(pres)
  names(ab_ph_rel) <- to_display_name(names(ab_ph_rel))
  print(max(ab_ph_rel))

  imp_ab3 <- merge(data.frame(Feature = names(ab_ph_rel), Rel_Ab = ab_ph_rel),
                   imp, by = "Feature")
  # Negligible gains are flattened to zero.
  imp_ab3$Gain <- ifelse(imp_ab3$Gain < 0.01, 0, imp_ab3$Gain)

  # Two copies of the same rows: each feature keeps its gain in exactly one of
  # them (rare vs common), so it is drawn once, on the matching colour scale.
  is_rare <- imp_ab3$Rel_Ab < presence_cutoff
  imp_ab_aus <- imp_ab3
  imp_ab_aus$Gain <- ifelse(is_rare, imp_ab3$Gain, 0)
  imp_ab_pres <- imp_ab3
  imp_ab_pres$Gain <- ifelse(is_rare, 0, imp_ab3$Gain)

  variant <- if (sys == "Ssp") {
    "bottom"
  } else if (sys == "PD-T4-5") {
    "ylab"
  } else {
    "plain"
  }
  gp <- build_importance_plot(imp_ab_aus, imp_ab_pres, variant)

  # Named objects (Cas, CBASS, `R-M`, ...) are used by plot_grid() below;
  # plot_list holds the same plots, one element per defence system.
  assign(sys, gp)
  plot_list[[sys]] <- gp
}

# Importance of every feature for every defence system, stacked.
df_imp <- do.call(rbind, unname(imp_by_type))
  


library(ggpubr)
library(cowplot)

# Stack all importance barplots in one column. The last panel (Ssp) is twice as
# tall because it carries the rotated x-axis labels.
# hjust/vjust give one label offset per panel; the ninth entries match what
# cowplot previously produced by recycling the shorter 8-element vectors.
imp_plot <- plot_grid(
  Cas, CBASS, Gao_Qat, `R-M`, `PD-T4-5`, `PD-T7-5`, Gabija, RosmerTA, Ssp,
  labels = c("Cas", "CBASS", "Gao_Qat", "RM", "PD-T4-5", "PD-T7-5", "Gabija",
             "RosmerTA", "Ssp"),
  label_size = 15,
  vjust = c(1, 0.7, 0.7, 1, 0.7, 0.7, 0.7, 0.7, 1),
  hjust = c(-2.3, -1.2, -1.2, -2.3, -1.2, -1.2, -1.7, -1, -2.3),
  ncol = 1, nrow = 9,
  rel_heights = c(1, 1, 1, 1, 1, 1, 1, 1, 2)
)




# Combine the stacked barplots with a legend column, then save panel A.
# NOTE(review): `lg` is not created anywhere in this script, and every barplot
# uses show.legend = FALSE. If `lg` is absent, warn and use the barplots alone
# instead of stopping with "object 'lg' not found".
if (exists("lg")) {
  A <- plot_grid(imp_plot, lg, rel_widths = c(15, 1), nrow = 1, ncol = 2)
} else {
  warning("Legend object `lg` not found; panel A is drawn without a legend ",
          "column.", call. = FALSE)
  A <- imp_plot
}

pdf("fig5A.pdf", width = 21, height = 18.5)
print(A)
dev.off()
# PANEL B

library(ggVennDiagram)
library(eulerr)
library(nVennR)
#library(colortools)
library(ggpubr)

# Prophage ids whose genome overlaps are drawn in panel B, as two rows of
# panels (l1 and l2). Ids are appended to "p" to form each prophage's folder
# name and to "Phage " to form its panel label.
phages <- list(
  l1 = c("DgiS1", "1012", "1031", "1054", "1074", "1371"),
  l2 = c("1009", "1011", "1017", "1023", "1055", "1311")
)

# One euler diagram per prophage in `phages`, showing how three strain lists
# overlap: strains carrying the prophage, strains with CRISPR, and strains with
# spacers. Results are stored in l1 / l2, named by prophage id.
#
# Files are opened by full path rather than through setwd(). Changing the
# working directory here would also redirect every later relative output path
# (e.g. fig6.pdf) into the last prophage's folder.
base_dir <- "/home/brown/Documentos/Proyectos/defensome_aba"
spacer_root <- file.path(base_dir, "spacers", "spacers")

# Strains with CRISPR: the same file for every prophage, so it is read once.
crispr_strains <- readLines(file.path(base_dir, "crispr.ab"))

# Prophages whose panel gets a thick red frame instead of a thin black one.
highlighted_phages <- c("DgiS1", "1017", "1023", "1031")

# Build the framed euler panel for prophage id `ph`. Fill colours follow the
# set order Phage, Cas, Sp.
make_euler_panel <- function(ph) {
  ph_dir <- file.path(spacer_root, paste0("p", ph))
  ab_list <- list(
    Phage = readLines(file.path(ph_dir, "phage.ab")),  # strains with prophage
    Cas = crispr_strains,                               # strains with CRISPR
    Sp = readLines(file.path(ph_dir, "spacer.ab"))      # strains with spacers
  )
  highlight <- ph %in% highlighted_phages
  diagram <- plot(euler(ab_list),
                  fills = list(fill = c("#B4464B", "#4682B4", "#B4AF46"),
                               alpha = 0.4),
                  labels = FALSE,
                  quantities = list(cex = 0.8),
                  lty = 0)
  ggarrange(diagram) +
    theme(plot.background = element_rect(
      color = if (highlight) "red" else "black",
      linewidth = if (highlight) 2 else 1
    ))
}

l1 <- setNames(lapply(phages[[1]], make_euler_panel), phages[[1]])
l2 <- setNames(lapply(phages[[2]], make_euler_panel), phages[[2]])

# Lay out one horizontal row of euler panels, each labelled "Phage <id>".
phage_row <- function(panels, ids) {
  ggarrange(plotlist = panels, labels = paste0("Phage ", ids),
            nrow = 1, hjust = 0)
}

f1 <- phage_row(l1, phages[[1]])
f2 <- phage_row(l2, phages[[2]])
# The second group forms the top row, the first group the bottom row.
f <- ggarrange(f2, f1, nrow = 2)

# Stand-alone colour key for the euler diagrams. A throw-away bar chart exists
# only to expose its fill legend to get_legend().
# Colours are mapped by name so that each label matches its euler fill
# (Phage = #B4464B, Cas = #4682B4, Sp = #B4AF46). An unnamed vector would follow
# the alphabetical factor order and swap the phage and CRISPR-Cas colours.
legend_src <- data.frame(
  c1 = c(1, 2, 3),
  c3 = c("with CRISPR-Cas", "with phages", "with spacers")
)
legend_fills <- setNames(
  alpha(c("#4682B4", "#B4464B", "#B4AF46"), 0.4),
  c("with CRISPR-Cas", "with phages", "with spacers")
)
p <- ggplot(legend_src, aes(x = c1, fill = as.factor(c3))) +
  geom_bar(stat = "count") +
  scale_fill_manual(values = legend_fills) +
  theme(legend.position = "top") +
  labs(fill = "Genomes")
leg <- get_legend(p)

# Panel B: the colour key above the two rows of euler diagrams.
B <- ggarrange(leg, f, ncol = 1, heights = c(0.145, 1))

# Final figure: panel A (importance barplots) over panel B (euler diagrams).
# NOTE(review): the output is named fig6.pdf although the script header says
# FIGURE 5; confirm the intended file name.
fig6 <- plot_grid(A, B, ncol = 1, labels = c("A", "B"), rel_heights = c(7, 3))

pdf("fig6.pdf", width = 12, height = 18, paper = "special")
# Explicit print: grid/ggplot objects are not auto-printed when the script is
# run via source(), which would otherwise leave the PDF empty.
print(fig6)
dev.off()
  